package com.ml4ai.demo;

import com.ml4ai.nn.*;
import com.ml4ai.nn.core.Toolkit;
import com.ml4ai.nn.core.Variable;
import com.ml4ai.nn.core.optimizers.Moment;
import com.ml4ai.nn.core.optimizers.NNOptimizer;
import lombok.Getter;
import lombok.SneakyThrows;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.FileOutputStream;

/**
 * Generative adversarial network (GAN) demo.
 */
public class GenerativeAdversarialNet {

    /**
     * Network that bundles a generator and a discriminator.
     *
     * <p>Both sub-networks are registered on this network through {@code add},
     * and the parameters of each one can also be fetched on their own so that
     * separate optimizers can be attached to them.
     */
    @Getter
    public static class GAN extends BaseForwardNetwork {

        /** Number of values produced by the generator and consumed by the discriminator. */
        private static final int IMAGE_SIZE = 28 * 28;

        private ForwardNetwork generator;
        private ForwardNetwork discriminator;

        /**
         * Builds both sub-networks (generator first, then discriminator) and
         * registers them on this network in that same order.
         */
        public GAN() {
            // Generator: 3 -> 128 -> 256 -> 28*28, with Relu between the linear layers.
            ForwardNetwork generatorNet = new SequentialForward(
                    new Linear(3, 128),
                    new Relu(),
                    new Linear(128, 256),
                    new Relu(),
                    new Linear(256, IMAGE_SIZE)
            );
            // Discriminator: 28*28 -> 256 -> 128 -> 1, ending in a Sigmoid.
            ForwardNetwork discriminatorNet = new SequentialForward(
                    new Linear(IMAGE_SIZE, 256),
                    new Relu(),
                    new Linear(256, 128),
                    new Relu(),
                    new Linear(128, 1),
                    new Sigmoid()
            );

            this.generator = generatorNet;
            this.discriminator = discriminatorNet;

            add(generatorNet);
            add(discriminatorNet);
        }

        /**
         * Returns whatever the superclass reports as this network's parameters.
         *
         * @return the result of {@code super.getParameters()}
         */
        @Override
        public Variable[] getParameters() {
            return super.getParameters();
        }

        /**
         * @return the parameters of the discriminator sub-network only
         */
        public Variable[] getDiscriminatorParameters() {
            return this.discriminator.getParameters();
        }

        /**
         * @return the parameters of the generator sub-network only
         */
        public Variable[] getGeneratorParameters() {
            return this.generator.getParameters();
        }
    }

    /**
     * Shared MNIST iterator that training samples are drawn from.
     * Remains {@code null} when the data set could not be loaded during class
     * initialization.
     */
    private static DataSetIterator dataSetIt = null;

    static {
        try {
            // Arguments follow DL4J's MnistDataSetIterator(batch, numExamples) constructor.
            dataSetIt = new MnistDataSetIterator(1, 1000);
        } catch (Exception e) {
            // Loading stays best-effort (the class still initializes), but the
            // cause is reported instead of being silently discarded.
            System.out.println("数据集获取失败");
            System.err.println("MNIST data set could not be loaded: " + e);
        }
    }

    /**
     * Returns the features of the next 10 examples from the shared MNIST iterator.
     *
     * <p>The iterator is reset once it reports no further data, so repeated
     * calls keep cycling through the data set instead of failing when it is
     * exhausted.
     *
     * @return the feature array of the next 10 examples
     * @throws IllegalStateException if the data set iterator was not created
     *                               because loading failed at class initialization
     */
    public static INDArray takeSample() {
        if (dataSetIt == null) {
            throw new IllegalStateException(
                    "MNIST data set iterator is not available; loading the data set failed");
        }
        // Start over from the beginning once every example has been consumed.
        if (!dataSetIt.hasNext()) {
            dataSetIt.reset();
        }
        return dataSetIt.next(10).getFeatures();
    }

    /**
     * Trains the GAN for 10000 iterations.
     *
     * <p>Each iteration builds a 10x3 seed with {@code Nd4jUtil.rand}, runs it
     * through the generator, takes a batch of real samples, scores both with
     * the discriminator, and then updates the discriminator followed by the
     * generator. Every 100 iterations the generator loss is printed and both
     * the generated and the real batch are passed to {@code paint}.
     *
     * @param argv command-line arguments (unused)
     */
    @SneakyThrows
    public static void main(String[] argv) {
        final int iterations = 10000;
        final int reportEvery = 100;

        GAN net = new GAN();
        NNOptimizer genOptimizer = new Moment(net.getGeneratorParameters(), 1e-2, 0.95D);
        NNOptimizer discOptimizer = new Moment(net.getDiscriminatorParameters(), 1e-2, 0.95D);

        Toolkit toolkit = new Toolkit();
        for (int step = 0; step < iterations; step++) {
            // Forward passes: generated batch first, then the real batch.
            Variable noise = new Variable(Nd4jUtil.rand(new int[]{10, 3}));
            Variable fake = net.generator.forward(noise)[0];
            Variable real = new Variable(takeSample());
            Variable realScore = net.discriminator.forward(real)[0].mean();
            Variable fakeScore = net.discriminator.forward(fake)[0].mean();

            // Discriminator objective: log(D(x)) + log(1 - D(G(z))); its loss is the negation.
            Variable logReal = realScore.log(Math.E);
            Variable logOneMinusFake = new Variable(1).sub(fakeScore).log(Math.E);
            Variable discObjective = logReal.add(logOneMinusFake);

            // Generator loss: log(1 - D(G(z))), built as its own expression.
            Variable genLoss = new Variable(1).sub(fakeScore).log(Math.E);
            Variable discLoss = new Variable(0).sub(discObjective);

            // Discriminator step.
            toolkit.grad2zero(discLoss);
            toolkit.backward(discLoss);
            discOptimizer.update();

            // Generator step.
            toolkit.grad2zero(genLoss);
            toolkit.backward(genLoss);
            genOptimizer.update();

            if (step % reportEvery != 0) {
                continue;
            }
            System.out.println("loss:" + genLoss.data.scalar);
            new GenerativeAdversarialNet().paint(step, fake);
            new GenerativeAdversarialNet().paint(step + 1, real);
        }
    }


    /**
     * Writes the first 28*28 values of {@code var} as a 28x28 grayscale JPEG.
     *
     * <p>Each value is multiplied by 255, truncated to an int and clamped to
     * the range 0..255 before being used as the gray level. The file is written
     * to {@code d:\doc\data\<i>.jpg}.
     *
     * @param i   number used as the output file name
     * @param var variable whose tensor data is rendered
     * @throws IllegalArgumentException if the tensor holds fewer than 28*28 values
     * @throws Exception                if the image cannot be written
     */
    private void paint(int i, Variable var) throws Exception {
        final int side = 28;
        double[] data = var.data.tensor.data().asDouble();
        if (data.length < side * side) {
            throw new IllegalArgumentException(
                    "Expected at least " + (side * side) + " values but got " + data.length);
        }

        BufferedImage bi = new BufferedImage(side, side, BufferedImage.TYPE_INT_BGR);
        // NOTE(review): the value at index l * 28 + h is drawn at x = l, y = h. If the
        // tensor is laid out row-major this transposes the picture - confirm intent.
        for (int l = 0; l < side; l++) {
            for (int h = 0; h < side; h++) {
                int p = (int) (data[l * side + h] * 255);
                p = Math.max(0, Math.min(255, p));
                // Same gray level in all three channels; setRGB avoids holding a Graphics object.
                bi.setRGB(l, h, (p << 16) | (p << 8) | p);
            }
        }

        // try-with-resources guarantees the stream is closed even if encoding fails.
        try (FileOutputStream out = new FileOutputStream("d:\\doc\\data\\" + i + ".jpg")) {
            if (!ImageIO.write(bi, "jpg", out)) {
                throw new java.io.IOException("No ImageIO writer available for format jpg");
            }
        }
    }

}

